{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CrypTen - Training an Encrypted Neural Network across Workers using Plans\n",
    "\n",
    "We will train an encrypted neural network across different PySyft workers (deployed as [Grid Nodes](https://github.com/OpenMined/PyGrid/tree/dev/apps/node)). For this we will be using Plans and we will be using CrypTen as a backend for SMPC. \n",
    "\n",
    "\n",
    "Authors:\n",
    " - George Muraru - Twitter: [@gmuraru](https://twitter.com/georgemuraru)\n",
    " - Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Overivew\n",
    "* In this tutorial we will use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)\n",
    "* The features we need for training the network are split accross two workers (we will name them *alice* and *bob*)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Download/install needed repos\n",
    "* Clone the [PyGrid repo](https://github.com/OpenMined/PyGrid)\n",
    "  * we need this because *alice* and *bob* are two different Nodes in our network\n",
    "  * install the PyGrid node component using *poetry*\n",
    "\n",
    "### Bring up the PyGridNodes\n",
    "* In the *PyGrid* repo:\n",
    " 1. install *poetry* (```pip install poetry```)\n",
    " 2. go to *apps/nodes*\n",
    " 3. run ```poetry install``` (those steps are also in the README from the PyGrid repo)\n",
    " 4. start *bob* and *alice* using:\n",
    " ```\n",
    " ./run.sh --id alice --port 3000 --start_local_db\n",
    " ./run.sh --id bob --port 3001 --start_local_db\n",
    " ```\n",
    " \n",
    "This will start two workers, *alice* and *bob* and we will connect to them using the port 3000 and 3001.\n",
    "### Dataset preparation\n",
    "* Run the cell bellow to download a script from the CrypTen repository\n",
    "  * It will be used to split the features between the workers\n",
    "  * Each party will get only a subset of features.\n",
    "  * We will use only 100 entries from the dataset\n",
    "  * We will use binary classification (0 vs [1-9] digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-08-06 01:22:00--  https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.112.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.112.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 7401 (7.2K) [text/plain]\n",
      "Saving to: ‘mnist_utils.py’\n",
      "\n",
      "mnist_utils.py      100%[===================>]   7.23K  --.-KB/s    in 0.001s  \n",
      "\n",
      "2020-08-06 01:22:01 (4.93 MB/s) - ‘mnist_utils.py’ saved [7401/7401]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget \"https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py\" -O \"mnist_utils.py\"\n",
    "!python \"mnist_utils.py\" --option features --reduced 100 --binary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the ground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytest\n",
    "import crypten\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from time import time\n",
    "\n",
    "import syft as sy\n",
    "\n",
    "from syft.frameworks.crypten.context import run_multiworkers\n",
    "from syft.grid.clients.data_centric_fl_client import DataCentricFLClient\n",
    "\n",
    "torch.manual_seed(0)\n",
    "torch.set_num_threads(1)\n",
    "hook = sy.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural network to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ExampleNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)\n",
    "        self.fc1 = nn.Linear(16 * 12 * 12, 100)\n",
    "        self.fc2 = nn.Linear(100, 2)\n",
    " \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = out.view(-1, 16 * 12 * 12)\n",
    "        out = self.fc1(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup the workers and send them the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to have the two GridNodes workers running."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Connecting to workers ...\n",
      "[+] Connected to workers\n",
      "[%] Sending training data ...\n",
      "[+] Data ready\n"
     ]
    }
   ],
   "source": [
    "# Syft workers\n",
    "print(\"[%] Connecting to workers ...\")\n",
    "alice = DataCentricFLClient(hook, \"ws://localhost:3000\")\n",
    "bob = DataCentricFLClient(hook, \"ws://localhost:3001\")\n",
    "print(\"[+] Connected to workers\")\n",
    "\n",
    "print(\"[%] Sending training data ...\")\n",
    "\n",
    "# Prepare the labels\n",
    "label_eye = torch.eye(2)\n",
    "labels = torch.load(\"/tmp/train_labels.pth\")\n",
    "labels = labels.long()\n",
    "labels_one_hot = label_eye[labels]\n",
    "\n",
    "# Prepare and send training data\n",
    "alice_train = torch.load(\"/tmp/alice_train.pth\").tag(\"alice_train\")\n",
    "alice_ptr = alice_train.send(alice)\n",
    "bob_train = torch.load(\"/tmp/bob_train.pth\").tag(\"bob_train\")\n",
    "bob_ptr = bob_train.send(bob)\n",
    "\n",
    "print(\"[+] Data ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the data shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One entry from the MNIST dataset contains 28x28 features. Those are splitted accross our workers.\n",
    "\n",
    "We can check it out by running the next cell!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alice data shape torch.Size([100, 28, 20])\n",
      "Bob data shape torch.Size([100, 28, 8])\n"
     ]
    }
   ],
   "source": [
    "print(f\"Alice data shape {alice_train.shape}\")\n",
    "print(f\"Bob data shape {bob_train.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize a dummy model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instanciate a model and create a dummy input that could be forwarded through it. This is needed to build the CrypTen model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dummy_input = torch.empty(1, 1, 28, 28)\n",
    "pytorch_model = ExampleNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the CrypTen computation\n",
    "\n",
    "We need to specify for the ```run_multiworkers``` decorator:\n",
    "* the workers that will take part in the computation\n",
    "* the master address, this will be used for their synchronization\n",
    "* the instantiated model that will be sent\n",
    "* a dummy input for the model\n",
    "\n",
    "We will use the ```func2plan``` decorator to:\n",
    "* trace the operations from our function\n",
    "* sending the plan operations to *alice* and *bob* - the plans operations will act as the function\n",
    "* run the plans operations on both workers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ALICE = 0 # Alice rank in CrypTen\n",
    "BOB = 1 # Bob rank in CrypTen\n",
    "\n",
    "@run_multiworkers(\n",
    "    [alice, bob], master_addr=\"127.0.0.1\", model=pytorch_model, dummy_input=dummy_input\n",
    ")\n",
    "@sy.func2plan()\n",
    "def run_encrypted_training(\n",
    "    model=None,\n",
    "    learning_rate=0.001,\n",
    "    num_epochs=2,\n",
    "    batch_size=10,\n",
    "    num_batches=bob_ptr.shape[0]//10,\n",
    "    labels_one_hot=labels_one_hot,\n",
    "    crypten=crypten,\n",
    "    torch=torch,\n",
    "):\n",
    "    x_alice_enc = crypten.load(\"alice_train\", ALICE)\n",
    "    x_bob_enc = crypten.load(\"bob_train\", BOB)\n",
    "\n",
    "    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)\n",
    "    x_combined_enc = x_combined_enc.unsqueeze(1)\n",
    "\n",
    "    model.encrypt()\n",
    "    model.train()\n",
    "    loss = crypten.nn.MSELoss()\n",
    "\n",
    "    l_values = []\n",
    "\n",
    "    for i in range(num_epochs):\n",
    "        for batch in range(num_batches):\n",
    "            start, end = batch * batch_size, (batch + 1) * batch_size\n",
    "\n",
    "            x_train = x_combined_enc[start:end]\n",
    "            y_batch = labels_one_hot[start:end]\n",
    "            y_train = crypten.cryptensor(y_batch, requires_grad=True)\n",
    "\n",
    "            # perform forward pass:\n",
    "            output = model(x_train)\n",
    "            loss_value = loss(output, y_train)\n",
    "\n",
    "            # set gradients to \"zero\"\n",
    "            model.zero_grad()\n",
    "\n",
    "            # perform backward pass:\n",
    "            loss_value.backward()\n",
    "\n",
    "            # update parameters\n",
    "            model.update_parameters(learning_rate)\n",
    "\n",
    "            # Print progress every batch:\n",
    "            batch_loss = loss_value.get_plain_text()\n",
    "            l_values.append(batch_loss)\n",
    "\n",
    "    model.decrypt()\n",
    "    return (l_values, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the CrypTen computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run the computation defined above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Starting computation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[+] run_encrypted_training() took 70s\n",
      "Epoch 0:\n",
      "\tBatch 1 of 10 Loss: 0.4639\n",
      "\tBatch 2 of 10 Loss: 0.4665\n",
      "\tBatch 3 of 10 Loss: 0.4063\n",
      "\tBatch 4 of 10 Loss: 0.3487\n",
      "\tBatch 5 of 10 Loss: 0.3313\n",
      "\tBatch 6 of 10 Loss: 0.2796\n",
      "\tBatch 7 of 10 Loss: 0.2768\n",
      "\tBatch 8 of 10 Loss: 0.2432\n",
      "\tBatch 9 of 10 Loss: 0.2457\n",
      "\tBatch 10 of 10 Loss: 0.2003\n",
      "Epoch 1:\n",
      "\tBatch 1 of 10 Loss: 0.1624\n",
      "\tBatch 2 of 10 Loss: 0.1517\n",
      "\tBatch 3 of 10 Loss: 0.1551\n",
      "\tBatch 4 of 10 Loss: 0.1923\n",
      "\tBatch 5 of 10 Loss: 0.1320\n",
      "\tBatch 6 of 10 Loss: 0.1636\n",
      "\tBatch 7 of 10 Loss: 0.2245\n",
      "\tBatch 8 of 10 Loss: 0.1454\n",
      "\tBatch 9 of 10 Loss: 0.1718\n",
      "\tBatch 10 of 10 Loss: 0.1335\n"
     ]
    }
   ],
   "source": [
    "# Get the returned values\n",
    "# key 0 - return values for alice\n",
    "# key 1 - return values for bob\n",
    "print(\"[%] Starting computation\")\n",
    "func_ts = time()\n",
    "*losses, model = run_encrypted_training()[0]\n",
    "func_te = time()\n",
    "print(f\"[+] run_encrypted_training() took {int(func_te - func_ts)}s\")\n",
    "\n",
    "losses_per_epoch = len(losses) // 2\n",
    "\n",
    "for i in range(2):\n",
    "    print(f\"Epoch {i}:\")\n",
    "    for batch, loss in enumerate(losses[i * losses_per_epoch:(i+1) * losses_per_epoch]):\n",
    "        print(f\"\\tBatch {(batch+1)} of 10 Loss: {loss:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model returned is a CrypTen model, but we can always run the usual PySyft methods to share the parameters and so on, as far as the model in not encrypted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph unencrypted module\n",
      "(Wrapper)>FixedPrecisionTensor>[AdditiveSharingTensor]\n",
      "\t-> [PointerTensor | me:36251999356 -> alice:42758419602]\n",
      "\t-> [PointerTensor | me:60399685431 -> bob:89891973551]\n",
      "\t*crypto provider: cp*\n"
     ]
    }
   ],
   "source": [
    "cp = sy.VirtualWorker(hook=hook, id=\"cp\")\n",
    "model.fix_prec()\n",
    "model.share(alice, bob, crypto_provider=cp)\n",
    "print(model)\n",
    "print(list(model.parameters())[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CleanUp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CleanUp portion taken from the CrypTen project\n",
    "\n",
    "import os\n",
    "\n",
    "filenames = ['/tmp/alice_train.pth', \n",
    "             '/tmp/bob_train.pth', \n",
    "             '/tmp/alice_test.pth',\n",
    "             '/tmp/bob_test.pth', \n",
    "             '/tmp/train_labels.pth',\n",
    "             '/tmp/test_labels.pth',\n",
    "             'mnist_utils.py']\n",
    "\n",
    "for fn in filenames:\n",
    "    if os.path.exists(fn): os.remove(fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+first+issue+%3Amortar_board%3A%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
